package matrix;

/**
 * Basic integer matrix arithmetic: element-wise addition and subtraction,
 * matrix-by-matrix multiplication, and matrix-by-scalar multiplication.
 *
 * <p>Matrices are represented as {@code int[rows][columns]}. The matrix-matrix
 * operations require rectangular operands (non-null, every row non-null and of
 * the same length) with compatible dimensions, and return {@code null} when that
 * requirement is not met. All arithmetic uses Java {@code int} semantics, so
 * results silently wrap around on overflow.
 *
 * <p>Originally created 3/17/17.
 */
public class BasicMatrixMath {
    /** Operation code for element-wise addition, used by dimension validation. */
    public static final int OPERATION_ADD = 1;
    /** Operation code for element-wise subtraction, used by dimension validation. */
    public static final int OPERATION_SUB = 2;
    /** Operation code for matrix-by-matrix multiplication, used by dimension validation. */
    public static final int OPERATION_MUL = 3;

    /**
     * Adds two matrices element by element.
     *
     * <p>To be able to add two matrices, they must be of the same size.
     *
     * @param matrixa left operand
     * @param matrixb right operand, same dimensions as {@code matrixa}
     * @return a new matrix whose cell {@code [i][j]} is
     *     {@code matrixa[i][j] + matrixb[i][j]}, or {@code null} if the operands
     *     are not rectangular matrices of the same size
     */
    public int[][] add(int[][] matrixa, int[][] matrixb) {
        if (!legalOperation(matrixa, matrixb, OPERATION_ADD)) {
            // Signal invalid input the same way multiplication() does, instead of
            // returning a zero matrix that looks like a genuine result.
            return null;
        }
        int rows = matrixa.length;
        int cols = columns(matrixa);
        int[][] sum = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                sum[i][j] = matrixa[i][j] + matrixb[i][j];
            }
        }
        return sum;
    }

    /**
     * Subtracts {@code matrixb} from {@code matrixa} element by element.
     *
     * <p>To be able to subtract two matrices, they must be of the same size.
     *
     * @param matrixa minuend
     * @param matrixb subtrahend, same dimensions as {@code matrixa}
     * @return a new matrix whose cell {@code [i][j]} is
     *     {@code matrixa[i][j] - matrixb[i][j]}, or {@code null} if the operands
     *     are not rectangular matrices of the same size
     */
    public int[][] substract(int[][] matrixa, int[][] matrixb) {
        if (!legalOperation(matrixa, matrixb, OPERATION_SUB)) {
            // Same failure signal as add() and multiplication().
            return null;
        }
        int rows = matrixa.length;
        int cols = columns(matrixa);
        int[][] difference = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                difference[i][j] = matrixa[i][j] - matrixb[i][j];
            }
        }
        return difference;
    }

    /**
     * Computes the matrix product {@code matrixa × matrixb}.
     *
     * <p>The number of columns of {@code matrixa} must equal the number of rows of
     * {@code matrixb}; the result has {@code matrixa.length} rows and as many
     * columns as {@code matrixb}.
     *
     * @param matrixa left operand, {@code n × m}
     * @param matrixb right operand, {@code m × p}
     * @return a new {@code n × p} matrix, or {@code null} if the operands are not
     *     rectangular or their inner dimensions do not match
     */
    public int[][] multiplication(int[][] matrixa, int[][] matrixb) {
        if (!legalOperation(matrixa, matrixb, OPERATION_MUL)) {
            return null;
        }
        int rows = matrixa.length;
        // The product has as many columns as the RIGHT operand, not the left one.
        int cols = columns(matrixb);
        int[][] product = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                product[i][j] = calculateSingleResult(matrixa, matrixb, i, j);
            }
        }
        return product;
    }

    /**
     * Computes one cell of {@code matrixa × matrixb}: the dot product of row
     * {@code row} of {@code matrixa} with column {@code col} of {@code matrixb}.
     * Assumes the dimensions were already validated by {@link #legalOperation}.
     */
    private int calculateSingleResult(int[][] matrixa, int[][] matrixb, int row, int col) {
        int result = 0;
        // matrixb.length is the shared inner dimension (== columns of matrixa).
        for (int k = 0; k < matrixb.length; k++) {
            result += matrixa[row][k] * matrixb[k][col];
        }
        return result;
    }

    /**
     * Multiplies every element of a matrix by a scalar.
     *
     * @param matrixa the matrix to scale; each row is scaled over its own length,
     *     so an empty matrix yields an empty result
     * @param b the scalar factor
     * @return a new matrix whose cell {@code [i][j]} is {@code matrixa[i][j] * b}
     */
    public int[][] multiplication(int[][] matrixa, int b) {
        int[][] scaled = new int[matrixa.length][];
        for (int i = 0; i < matrixa.length; i++) {
            scaled[i] = new int[matrixa[i].length];
            for (int j = 0; j < matrixa[i].length; j++) {
                scaled[i][j] = matrixa[i][j] * b;
            }
        }
        return scaled;
    }

    /**
     * Validates that two matrices can be combined by the given operation.
     *
     * @param a left operand
     * @param b right operand
     * @param type one of {@link #OPERATION_ADD}, {@link #OPERATION_SUB},
     *     {@link #OPERATION_MUL}
     * @return {@code true} if both operands are rectangular and their dimensions
     *     suit {@code type}; {@code false} otherwise, including for unknown codes
     */
    private boolean legalOperation(int[][] a, int[][] b, int type) {
        if (!isRectangular(a) || !isRectangular(b)) {
            return false;
        }
        switch (type) {
            case OPERATION_ADD:
            case OPERATION_SUB:
                return a.length == b.length && columns(a) == columns(b);
            case OPERATION_MUL:
                return columns(a) == b.length;
            default:
                return false;
        }
    }

    /**
     * Returns the column count of a rectangular matrix, treating a matrix with
     * no rows as having zero columns.
     */
    private static int columns(int[][] m) {
        return m.length == 0 ? 0 : m[0].length;
    }

    /** Returns whether {@code m} is non-null and every row is non-null and of equal length. */
    private static boolean isRectangular(int[][] m) {
        if (m == null) {
            return false;
        }
        for (int[] row : m) {
            // The null check short-circuits before m[0].length is read, so a null
            // first row is rejected rather than throwing.
            if (row == null || row.length != m[0].length) {
                return false;
            }
        }
        return true;
    }

    /** Prints {@code matrix} row by row, each value preceded by a tab. */
    private static void printMatrix(int[][] matrix) {
        for (int[] row : matrix) {
            for (int value : row) {
                System.out.print("\t" + value);
            }
            System.out.println();
        }
    }

    /** Small demo exercising each operation on two 2x2 matrices. */
    public static void main(String[] args) {
        int[][] a = new int[][] { { 1, 2 }, { 3, 4 } };
        int[][] b = new int[][] { { 7, 8 }, { 6, 5 } };
        BasicMatrixMath bmm = new BasicMatrixMath();
        System.out.println("addition two matrix");
        printMatrix(bmm.add(a, b));
        System.out.println("substract two matrix");
        printMatrix(bmm.substract(a, b));
        System.out.println("multiplex one matrix");
        printMatrix(bmm.multiplication(a, 3));
        System.out.println("multiplex two matrix");
        printMatrix(bmm.multiplication(a, b));
    }

}
